{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Restrict this process to GPU 0; set here, before the cell that imports TensorFlow.\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/xlnet/model_utils.py:295: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from xlnet import xlnet\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tqdm import tqdm\n",
    "from xlnet import model_utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !wget https://f000.backblazeb2.com/file/malaya-model/bert-bahasa/xlnet-base-29-03-2020.tar.gz\n",
    "# !tar -zxf xlnet-base-29-03-2020.tar.gz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "config.json\t\t\t       model.ckpt-300000.meta\r\n",
      "model.ckpt-300000.data-00000-of-00001  sp10m.cased.v9.model\r\n",
      "model.ckpt-300000.index\t\t       sp10m.cased.v9.vocab\r\n"
     ]
    }
   ],
   "source": [
    "# List the extracted checkpoint directory to confirm which files are available.\n",
    "!ls xlnet-base-29-03-2020"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import sentencepiece as spm\n",
    "from xlnet.prepro_utils import preprocess_text, encode_ids\n",
    "\n",
    "# Load the SentencePiece model that ships with the checkpoint. Load() is the last\n",
    "# expression of the cell, so its return value (True in the recorded run) is displayed.\n",
    "sp_model = spm.SentencePieceProcessor()\n",
    "sp_model.Load('xlnet-base-29-03-2020/sp10m.cased.v9.model')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Read the dataset with an explicit UTF-8 encoding instead of relying on the\n",
    "# platform's default text encoding, which differs between systems.\n",
    "with open('dataset.json', encoding = 'utf-8') as fopen:\n",
    "    data = json.load(fopen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull the four splits out of the loaded dataset: inputs (X) and labels (Y) for train and test.\n",
    "train_X, train_Y, test_X, test_Y = (\n",
    "    data[split] for split in ('train_X', 'train_Y', 'test_X', 'test_Y')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize_fn(text):\n",
    "    \"\"\"Preprocess ``text`` without lower-casing and encode it with ``sp_model``.\n",
    "\n",
    "    Returns whatever ``encode_ids`` produces for the preprocessed text.\n",
    "    \"\"\"\n",
    "    return encode_ids(sp_model, preprocess_text(text, lower = False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "SEG_ID_A   = 0\n",
    "SEG_ID_B   = 1\n",
    "SEG_ID_CLS = 2\n",
    "SEG_ID_SEP = 3\n",
    "SEG_ID_PAD = 4\n",
    "\n",
    "special_symbols = {\n",
    "    \"<unk>\"  : 0,\n",
    "    \"<s>\"    : 1,\n",
    "    \"</s>\"   : 2,\n",
    "    \"<cls>\"  : 3,\n",
    "    \"<sep>\"  : 4,\n",
    "    \"<pad>\"  : 5,\n",
    "    \"<mask>\" : 6,\n",
    "    \"<eod>\"  : 7,\n",
    "    \"<eop>\"  : 8,\n",
    "}\n",
    "\n",
    "VOCAB_SIZE = 32000\n",
    "UNK_ID = special_symbols[\"<unk>\"]\n",
    "CLS_ID = special_symbols[\"<cls>\"]\n",
    "SEP_ID = special_symbols[\"<sep>\"]\n",
    "MASK_ID = special_symbols[\"<mask>\"]\n",
    "EOD_ID = special_symbols[\"<eod>\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "from unidecode import unidecode\n",
    "\n",
    "def cleaning(string):\n",
    "    \"\"\"Normalise one raw text before tokenisation.\n",
    "\n",
    "    The text is passed through ``unidecode``, URL-like substrings are removed,\n",
    "    and whitespace-separated tokens starting with ``@`` are dropped. The\n",
    "    remaining tokens are returned joined by single spaces.\n",
    "    \"\"\"\n",
    "    ascii_text = unidecode(string)\n",
    "    without_urls = re.sub(\n",
    "        r'\\w+:\\/{2}[\\d\\w-]+(\\.[\\d\\w-]+)*(?:(?:\\/[^\\s/]*))*', '', ascii_text\n",
    "    )\n",
    "    tokens = re.sub(r'[ ]+', ' ', without_urls).strip().split()\n",
    "    # split() never yields empty tokens, so startswith is equivalent to checking token[0].\n",
    "    return ' '.join(token for token in tokens if not token.startswith('@'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 51761/51761 [00:17<00:00, 2903.89it/s]\n",
      "100%|██████████| 5029/5029 [00:01<00:00, 2792.18it/s]\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "# Clean every text in place, so the lists bound to train_X / test_X hold the cleaned strings.\n",
    "for texts in (train_X, test_X):\n",
    "    for position, raw_text in enumerate(tqdm(texts)):\n",
    "        texts[position] = cleaning(raw_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "MAX_SEQ_LENGTH = 510\n",
    "\n",
    "def get_tokenize(strings):\n",
    "    \"\"\"Encode each text as ``token ids + SEP_ID + CLS_ID`` with matching masks and segments.\n",
    "\n",
    "    Token ids are truncated so that each sequence, including the two trailing\n",
    "    special ids, holds at most ``MAX_SEQ_LENGTH`` items. Sequences are not padded here.\n",
    "\n",
    "    Returns:\n",
    "        Three parallel lists ``(input_ids, input_masks, segment_ids)`` with one\n",
    "        entry per text; every mask is all zeros.\n",
    "    \"\"\"\n",
    "    max_tokens = MAX_SEQ_LENGTH - 2  # leave room for SEP_ID and CLS_ID\n",
    "    input_ids, input_masks, segment_ids = [], [], []\n",
    "    for text in tqdm(strings):\n",
    "        ids = tokenize_fn(text)[:max_tokens]\n",
    "        ids.extend((SEP_ID, CLS_ID))\n",
    "        # The separator shares segment A with the text; the final id gets SEG_ID_CLS.\n",
    "        segments = [SEG_ID_A] * (len(ids) - 1) + [SEG_ID_CLS]\n",
    "\n",
    "        input_ids.append(ids)\n",
    "        input_masks.append([0] * len(ids))\n",
    "        segment_ids.append(segments)\n",
    "\n",
    "    return input_ids, input_masks, segment_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 51761/51761 [03:37<00:00, 238.44it/s]\n",
      "100%|██████████| 5029/5029 [00:23<00:00, 213.83it/s]\n"
     ]
    }
   ],
   "source": [
    "# Tokenise both splits once; each call returns (input ids, input masks, segment ids).\n",
    "train_input_ids, train_input_masks, train_segment_ids = get_tokenize(train_X)\n",
    "test_input_ids, test_input_masks, test_segment_ids = get_tokenize(test_X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/xlnet/xlnet.py:64: The name tf.gfile.Open is deprecated. Please use tf.io.gfile.GFile instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Run-time options for building the XLNet graph in training mode\n",
    "# (is_training=True, dropout and dropatt both 0.1); passed to xlnet.RunConfig below.\n",
    "kwargs = dict(\n",
    "      is_training=True,\n",
    "      use_tpu=False,\n",
    "      use_bfloat16=False,\n",
    "      dropout=0.1,\n",
    "      dropatt=0.1,\n",
    "      init='normal',\n",
    "      init_range=0.1,\n",
    "      init_std=0.05,\n",
    "      clamp_len=-1)\n",
    "\n",
    "xlnet_parameters = xlnet.RunConfig(**kwargs)\n",
    "# Architecture settings are read from the config file shipped with the checkpoint.\n",
    "xlnet_config = xlnet.XLNetConfig(json_path='xlnet-base-29-03-2020/config.json')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "64701 6470\n"
     ]
    }
   ],
   "source": [
    "epoch = 10\n",
    "batch_size = 8\n",
    "warmup_proportion = 0.1\n",
    "# Total optimiser steps for `epoch` passes over the training set; int() truncates\n",
    "# the fractional step left by a final partial batch.\n",
    "num_train_steps = int(len(train_X) / batch_size * epoch)\n",
    "num_warmup_steps = int(num_train_steps * warmup_proportion)\n",
    "print(num_train_steps, num_warmup_steps)\n",
    "learning_rate = 2e-5\n",
    "\n",
    "# Optimiser settings, converted to a Parameter object in the next cell. Only the keys\n",
    "# named in Parameter.__init__ are kept there; use_bfloat16, dropout, dropatt, init,\n",
    "# init_range, init_std and clamp_len are absorbed by its **kwargs and not stored.\n",
    "training_parameters = dict(\n",
    "      decay_method = 'poly',\n",
    "      train_steps = num_train_steps,\n",
    "      learning_rate = learning_rate,\n",
    "      warmup_steps = num_warmup_steps,\n",
    "      min_lr_ratio = 0.0,\n",
    "      weight_decay = 0.00,\n",
    "      adam_epsilon = 1e-8,\n",
    "      num_core_per_host = 1,\n",
    "      lr_layer_decay_rate = 1,\n",
    "      use_tpu=False,\n",
    "      use_bfloat16=False,\n",
    "      dropout=0.0,\n",
    "      dropatt=0.0,\n",
    "      init='normal',\n",
    "      init_range=0.1,\n",
    "      init_std=0.05,\n",
    "      clip = 1.0,\n",
    "      clamp_len=-1,)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Parameter:\n",
    "    def __init__(self, decay_method, warmup_steps, weight_decay, adam_epsilon, \n",
    "                num_core_per_host, lr_layer_decay_rate, use_tpu, learning_rate, train_steps,\n",
    "                min_lr_ratio, clip, **kwargs):\n",
    "        self.decay_method = decay_method\n",
    "        self.warmup_steps = warmup_steps\n",
    "        self.weight_decay = weight_decay\n",
    "        self.adam_epsilon = adam_epsilon\n",
    "        self.num_core_per_host = num_core_per_host\n",
    "        self.lr_layer_decay_rate = lr_layer_decay_rate\n",
    "        self.use_tpu = use_tpu\n",
    "        self.learning_rate = learning_rate\n",
    "        self.train_steps = train_steps\n",
    "        self.min_lr_ratio = min_lr_ratio\n",
    "        self.clip = clip\n",
    "        \n",
    "# Convert the settings dict into a Parameter object only while it is still a dict, so\n",
    "# re-running this cell does not fail with a TypeError from unpacking a Parameter instance.\n",
    "if isinstance(training_parameters, dict):\n",
    "    training_parameters = Parameter(**training_parameters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model:\n",
    "    \"\"\"XLNet network topped with a dense layer, used as a sequence classifier.\n",
    "\n",
    "    Creating an instance adds placeholders, the XLNet network, the loss, the training\n",
    "    op and an accuracy metric to the current default graph. It reads the notebook\n",
    "    globals ``xlnet_config``, ``xlnet_parameters`` and ``training_parameters``.\n",
    "\n",
    "    Args:\n",
    "        dimension_output: number of units of the dense output layer.\n",
    "        learning_rate: accepted but never read; the training op is configured from\n",
    "            ``training_parameters`` and ``self.learning_rate`` is set by ``get_train_op``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        dimension_output,\n",
    "        learning_rate = 2e-5,\n",
    "    ):\n",
    "        # Inputs: token ids, segment ids, input mask (all 2-D) and labels (1-D).\n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.segment_ids = tf.placeholder(tf.int32, [None, None])\n",
    "        self.input_masks = tf.placeholder(tf.float32, [None, None])\n",
    "        self.Y = tf.placeholder(tf.int32, [None])\n",
    "\n",
    "        # XLNetModel is fed each 2-D input with its two axes swapped.\n",
    "        ids_t, segments_t, masks_t = [\n",
    "            tf.transpose(tensor, [1, 0])\n",
    "            for tensor in (self.X, self.segment_ids, self.input_masks)\n",
    "        ]\n",
    "        encoder = xlnet.XLNetModel(\n",
    "            xlnet_config = xlnet_config,\n",
    "            run_config = xlnet_parameters,\n",
    "            input_ids = ids_t,\n",
    "            seg_ids = segments_t,\n",
    "            input_mask = masks_t,\n",
    "        )\n",
    "\n",
    "        # Swap the first two axes of the sequence output back before the dense layer.\n",
    "        sequence_output = tf.transpose(encoder.get_sequence_output(), [1, 0, 2])\n",
    "\n",
    "        self.logits_seq = tf.identity(\n",
    "            tf.layers.dense(sequence_output, dimension_output), name = 'logits_seq'\n",
    "        )\n",
    "        # The classification logits are the dense outputs at index 0 of the second axis.\n",
    "        self.logits = tf.identity(self.logits_seq[:, 0], name = 'logits')\n",
    "\n",
    "        per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
    "            logits = self.logits, labels = self.Y\n",
    "        )\n",
    "        self.cost = tf.reduce_mean(per_example_loss)\n",
    "\n",
    "        # get_train_op returns three values; the third is not needed here.\n",
    "        self.optimizer, self.learning_rate, _ = model_utils.get_train_op(\n",
    "            training_parameters, self.cost\n",
    "        )\n",
    "\n",
    "        predicted = tf.argmax(self.logits, 1, output_type = tf.int32)\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(tf.equal(predicted, self.Y), tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/xlnet/xlnet.py:221: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/xlnet/xlnet.py:221: The name tf.AUTO_REUSE is deprecated. Please use tf.compat.v1.AUTO_REUSE instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/xlnet/modeling.py:453: The name tf.logging.info is deprecated. Please use tf.compat.v1.logging.info instead.\n",
      "\n",
      "INFO:tensorflow:memory input None\n",
      "INFO:tensorflow:Use float type <dtype: 'float32'>\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/xlnet/modeling.py:460: The name tf.get_variable is deprecated. Please use tf.compat.v1.get_variable instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/xlnet/modeling.py:535: dropout (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.dropout instead.\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/layers/core.py:271: Layer.apply (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `layer.__call__` method instead.\n",
      "WARNING:tensorflow:\n",
      "The TensorFlow contrib module will not be included in TensorFlow 2.0.\n",
      "For more information, please see:\n",
      "  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n",
      "  * https://github.com/tensorflow/addons\n",
      "  * https://github.com/tensorflow/io (for I/O related ops)\n",
      "If you depend on functionality not listed there, please file an issue.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/xlnet/modeling.py:67: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.Dense instead.\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/xlnet/model_utils.py:96: The name tf.train.get_or_create_global_step is deprecated. Please use tf.compat.v1.train.get_or_create_global_step instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/xlnet/model_utils.py:108: The name tf.train.polynomial_decay is deprecated. Please use tf.compat.v1.train.polynomial_decay instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/xlnet/model_utils.py:123: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/xlnet/model_utils.py:131: The name tf.train.AdamOptimizer is deprecated. Please use tf.compat.v1.train.AdamOptimizer instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "dimension_output = 2\n",
    "\n",
    "# Start from an empty default graph, open a session and build the model in it.\n",
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Model(\n",
    "    dimension_output,\n",
    "    learning_rate\n",
    ")\n",
    "\n",
    "# Initialise every variable; pretrained weights are restored over them in a later cell.\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "import re\n",
    "\n",
    "def get_assignment_map_from_checkpoint(tvars, init_checkpoint):\n",
    "    \"\"\"Match checkpoint variable names against the variables in ``tvars``.\n",
    "\n",
    "    Args:\n",
    "        tvars: variables of the current graph (objects exposing ``.name``).\n",
    "        init_checkpoint: checkpoint path handed to ``tf.train.list_variables``.\n",
    "\n",
    "    Returns:\n",
    "        ``(assignment_map, initialized_variable_names)``. ``assignment_map`` is an\n",
    "        OrderedDict ``{checkpoint name: variable}`` holding only names present in\n",
    "        both the checkpoint and ``tvars``, in checkpoint listing order.\n",
    "        ``initialized_variable_names`` maps every matched name, with and without\n",
    "        the ``:0`` suffix, to 1.\n",
    "    \"\"\"\n",
    "    name_to_variable = collections.OrderedDict()\n",
    "    for var in tvars:\n",
    "        # Drop a trailing ':<digits>' suffix so names compare equal to checkpoint names.\n",
    "        m = re.match('^(.*):\\\\d+$', var.name)\n",
    "        name = m.group(1) if m is not None else var.name\n",
    "        name_to_variable[name] = var\n",
    "\n",
    "    assignment_map = collections.OrderedDict()\n",
    "    initialized_variable_names = {}\n",
    "    # Only the first item of each listing entry (the variable name) is needed.\n",
    "    for entry in tf.train.list_variables(init_checkpoint):\n",
    "        name = entry[0]\n",
    "        if name not in name_to_variable:\n",
    "            continue\n",
    "        assignment_map[name] = name_to_variable[name]\n",
    "        initialized_variable_names[name] = 1\n",
    "        initialized_variable_names[name + ':0'] = 1\n",
    "\n",
    "    return (assignment_map, initialized_variable_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Match the graph's trainable variables against the pretrained checkpoint by name.\n",
    "checkpoint = 'xlnet-base-29-03-2020/model.ckpt-300000'\n",
    "tvars = tf.trainable_variables()\n",
    "assignment_map, initialized_variable_names = get_assignment_map_from_checkpoint(\n",
    "    tvars, checkpoint\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from xlnet-base-29-03-2020/model.ckpt-300000\n"
     ]
    }
   ],
   "source": [
    "# Restore only the variables listed in assignment_map (those found in the checkpoint);\n",
    "# any other variable keeps the value it got from global_variables_initializer.\n",
    "saver = tf.train.Saver(var_list = assignment_map)\n",
    "saver.restore(sess, checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.5, 2.7506523]"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pad_sequences = tf.keras.preprocessing.sequence.pad_sequences\n",
    "\n",
    "# Smoke test: evaluate accuracy and loss on the first four training examples\n",
    "# (no optimiser step) to confirm the graph runs with the restored weights.\n",
    "i = 0\n",
    "index = 4\n",
    "batch_x = pad_sequences(train_input_ids[i : index], padding='post')\n",
    "batch_y = train_Y[i : index]\n",
    "# Padded positions are marked 1 in the mask (real tokens are 0) and use the named\n",
    "# padding segment id rather than a bare literal.\n",
    "batch_masks = pad_sequences(train_input_masks[i : index], padding='post', value = 1)\n",
    "batch_segments = pad_sequences(train_segment_ids[i : index], padding='post', value = SEG_ID_PAD)\n",
    "\n",
    "sess.run(\n",
    "    [model.accuracy, model.cost],\n",
    "    feed_dict = {\n",
    "        model.X: batch_x,\n",
    "        model.Y: batch_y,\n",
    "        model.segment_ids: batch_segments,\n",
    "        model.input_masks: batch_masks,\n",
    "    },\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 6471/6471 [1:16:09<00:00,  1.42it/s, accuracy=0, cost=5.63]      \n",
      "test minibatch loop: 100%|██████████| 629/629 [02:44<00:00,  3.83it/s, accuracy=1, cost=0.0108]    \n",
      "train minibatch loop:   0%|          | 0/6471 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, pass acc: 0.000000, current acc: 0.858704\n",
      "time taken: 4734.145400762558\n",
      "epoch: 0, training loss: 0.640360, training acc: 0.753438, valid loss: 0.442873, valid acc: 0.858704\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 6471/6471 [1:15:36<00:00,  1.43it/s, accuracy=0, cost=10.4]      \n",
      "test minibatch loop: 100%|██████████| 629/629 [02:44<00:00,  3.83it/s, accuracy=1, cost=0.0293]    \n",
      "train minibatch loop:   0%|          | 0/6471 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1, pass acc: 0.858704, current acc: 0.884142\n",
      "time taken: 4700.415242433548\n",
      "epoch: 1, training loss: 0.355910, training acc: 0.882437, valid loss: 0.491568, valid acc: 0.884142\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 6471/6471 [1:15:36<00:00,  1.43it/s, accuracy=0, cost=9.13]      \n",
      "test minibatch loop: 100%|██████████| 629/629 [02:44<00:00,  3.83it/s, accuracy=1, cost=0.00359]   \n",
      "train minibatch loop:   0%|          | 0/6471 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 2, pass acc: 0.884142, current acc: 0.896065\n",
      "time taken: 4701.144730329514\n",
      "epoch: 2, training loss: 0.256305, training acc: 0.931038, valid loss: 0.510110, valid acc: 0.896065\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 6471/6471 [1:15:25<00:00,  1.43it/s, accuracy=0, cost=10.3]      \n",
      "test minibatch loop: 100%|██████████| 629/629 [02:44<00:00,  3.83it/s, accuracy=1, cost=0.00068]   \n",
      "train minibatch loop:   0%|          | 0/6471 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 3, pass acc: 0.896065, current acc: 0.899444\n",
      "time taken: 4690.006617069244\n",
      "epoch: 3, training loss: 0.179855, training acc: 0.958488, valid loss: 0.566977, valid acc: 0.899444\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 6471/6471 [1:15:38<00:00,  1.43it/s, accuracy=0, cost=7.23]      \n",
      "test minibatch loop: 100%|██████████| 629/629 [02:44<00:00,  3.83it/s, accuracy=1, cost=0.00013]   \n",
      "train minibatch loop:   0%|          | 0/6471 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 4, pass acc: 0.899444, current acc: 0.904014\n",
      "time taken: 4702.562220573425\n",
      "epoch: 4, training loss: 0.123122, training acc: 0.973227, valid loss: 0.692606, valid acc: 0.904014\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 6471/6471 [1:15:38<00:00,  1.43it/s, accuracy=0, cost=6.2]       \n",
      "test minibatch loop: 100%|██████████| 629/629 [02:44<00:00,  3.83it/s, accuracy=1, cost=7.02e-5]   \n",
      "train minibatch loop:   0%|          | 0/6471 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 5, pass acc: 0.904014, current acc: 0.909579\n",
      "time taken: 4702.606304883957\n",
      "epoch: 5, training loss: 0.087910, training acc: 0.981533, valid loss: 0.726763, valid acc: 0.909579\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 6471/6471 [1:15:38<00:00,  1.43it/s, accuracy=0, cost=4.07]      \n",
      "test minibatch loop: 100%|██████████| 629/629 [02:44<00:00,  3.82it/s, accuracy=1, cost=0.000132] \n",
      "train minibatch loop:   0%|          | 0/6471 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 6, pass acc: 0.909579, current acc: 0.916137\n",
      "time taken: 4702.851442337036\n",
      "epoch: 6, training loss: 0.066086, training acc: 0.986266, valid loss: 0.629885, valid acc: 0.916137\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 6471/6471 [1:15:37<00:00,  1.43it/s, accuracy=0, cost=4.51]      \n",
      "test minibatch loop: 100%|██████████| 629/629 [02:44<00:00,  3.83it/s, accuracy=1, cost=1.98e-6]   \n",
      "train minibatch loop:   0%|          | 0/6471 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 7, pass acc: 0.916137, current acc: 0.916733\n",
      "time taken: 4701.885998725891\n",
      "epoch: 7, training loss: 0.047587, training acc: 0.990013, valid loss: 0.753258, valid acc: 0.916733\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 6471/6471 [1:15:38<00:00,  1.43it/s, accuracy=1, cost=0.0185]   \n",
      "test minibatch loop: 100%|██████████| 629/629 [02:44<00:00,  3.82it/s, accuracy=1, cost=9.54e-8]   \n",
      "train minibatch loop:   0%|          | 0/6471 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 8, pass acc: 0.916733, current acc: 0.918323\n",
      "time taken: 4702.687706708908\n",
      "epoch: 8, training loss: 0.034691, training acc: 0.991945, valid loss: 0.843098, valid acc: 0.918323\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop:  72%|███████▏  | 4637/6471 [54:10<21:25,  1.43it/s, accuracy=1, cost=2.02e-5]    IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n",
      "test minibatch loop: 100%|██████████| 629/629 [02:44<00:00,  3.83it/s, accuracy=1, cost=3.34e-5]  \n",
      "train minibatch loop:   0%|          | 0/6471 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 9, pass acc: 0.918323, current acc: 0.921701\n",
      "time taken: 4701.4859120845795\n",
      "epoch: 9, training loss: 0.024727, training acc: 0.993877, valid loss: 0.826608, valid acc: 0.921701\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop:  23%|██▎       | 1456/6471 [17:01<58:31,  1.43it/s, accuracy=1, cost=0.0619]     IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n",
      "train minibatch loop:  82%|████████▏ | 5330/6471 [1:02:18<13:02,  1.46it/s, accuracy=1, cost=2.18e-5]  IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "import time\n",
    "\n",
    "\n",
    "# EARLY_STOPPING is the number of consecutive epochs without a validation-accuracy\n",
    "# improvement after which training stops; CURRENT_CHECKPOINT counts those epochs.\n",
    "EARLY_STOPPING, CURRENT_CHECKPOINT, CURRENT_ACC, EPOCH = 1, 0, 0, 0\n",
    "\n",
    "# Train for at most `epoch` epochs: num_train_steps, and therefore the learning-rate\n",
    "# schedule, was sized for exactly that many, so additional epochs would run past the\n",
    "# end of the schedule. Early stopping can still end training sooner.\n",
    "while EPOCH < epoch:\n",
    "    lasttime = time.time()\n",
    "    if CURRENT_CHECKPOINT == EARLY_STOPPING:\n",
    "        print('break epoch:%d\\n' % (EPOCH))\n",
    "        break\n",
    "\n",
    "    train_acc, train_loss, test_acc, test_loss = [], [], [], []\n",
    "    pbar = tqdm(\n",
    "        range(0, len(train_input_ids), batch_size), desc = 'train minibatch loop'\n",
    "    )\n",
    "    for i in pbar:\n",
    "        index = min(i + batch_size, len(train_input_ids))\n",
    "        # Pad each batch to its own longest sequence; padded positions get mask 1\n",
    "        # and the padding segment id.\n",
    "        batch_x = pad_sequences(train_input_ids[i : index], padding='post')\n",
    "        batch_y = train_Y[i : index]\n",
    "        batch_masks = pad_sequences(train_input_masks[i : index], padding='post', value = 1)\n",
    "        batch_segments = pad_sequences(train_segment_ids[i : index], padding='post', value = SEG_ID_PAD)\n",
    "        acc, cost, _ = sess.run(\n",
    "            [model.accuracy, model.cost, model.optimizer],\n",
    "            feed_dict = {\n",
    "                model.X: batch_x,\n",
    "                model.Y: batch_y,\n",
    "                model.segment_ids: batch_segments,\n",
    "                model.input_masks: batch_masks,\n",
    "            },\n",
    "        )\n",
    "        train_loss.append(cost)\n",
    "        train_acc.append(acc)\n",
    "        pbar.set_postfix(cost = cost, accuracy = acc)\n",
    "\n",
    "    # Validation pass: same feeds, but the optimiser op is not run.\n",
    "    # NOTE(review): the graph was built from a RunConfig with is_training=True, so\n",
    "    # dropout may also be active during this pass - confirm.\n",
    "    pbar = tqdm(range(0, len(test_input_ids), batch_size), desc = 'test minibatch loop')\n",
    "    for i in pbar:\n",
    "        index = min(i + batch_size, len(test_input_ids))\n",
    "        batch_x = pad_sequences(test_input_ids[i : index], padding='post')\n",
    "        batch_y = test_Y[i : index]\n",
    "        batch_masks = pad_sequences(test_input_masks[i : index], padding='post', value = 1)\n",
    "        batch_segments = pad_sequences(test_segment_ids[i : index], padding='post', value = SEG_ID_PAD)\n",
    "        acc, cost = sess.run(\n",
    "            [model.accuracy, model.cost],\n",
    "            feed_dict = {\n",
    "                model.X: batch_x,\n",
    "                model.Y: batch_y,\n",
    "                model.segment_ids: batch_segments,\n",
    "                model.input_masks: batch_masks,\n",
    "            },\n",
    "        )\n",
    "        test_loss.append(cost)\n",
    "        test_acc.append(acc)\n",
    "        pbar.set_postfix(cost = cost, accuracy = acc)\n",
    "\n",
    "    # Epoch metrics are unweighted means over the per-batch values.\n",
    "    train_loss = np.mean(train_loss)\n",
    "    train_acc = np.mean(train_acc)\n",
    "    test_loss = np.mean(test_loss)\n",
    "    test_acc = np.mean(test_acc)\n",
    "\n",
    "    # Reset the patience counter on improvement, otherwise count the stalled epoch.\n",
    "    if test_acc > CURRENT_ACC:\n",
    "        print(\n",
    "            'epoch: %d, pass acc: %f, current acc: %f'\n",
    "            % (EPOCH, CURRENT_ACC, test_acc)\n",
    "        )\n",
    "        CURRENT_ACC = test_acc\n",
    "        CURRENT_CHECKPOINT = 0\n",
    "    else:\n",
    "        CURRENT_CHECKPOINT += 1\n",
    "\n",
    "    print('time taken:', time.time() - lasttime)\n",
    "    print(\n",
    "        'epoch: %d, training loss: %f, training acc: %f, valid loss: %f, valid acc: %f\\n'\n",
    "        % (EPOCH, train_loss, train_acc, test_loss, test_acc)\n",
    "    )\n",
    "    EPOCH += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'xlnet-base-relevancy/model.ckpt'"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Save only the trainable variables of the fine-tuned model; save() is the last\n",
    "# expression of the cell, so the checkpoint prefix it returns is displayed.\n",
    "saver = tf.train.Saver(tf.trainable_variables())\n",
    "saver.save(sess, 'xlnet-base-relevancy/model.ckpt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same run configuration as for training, except is_training=False and both\n",
    "# dropout rates set to 0.0; used to rebuild the graph for evaluation.\n",
    "kwargs = dict(\n",
    "      is_training=False,\n",
    "      use_tpu=False,\n",
    "      use_bfloat16=False,\n",
    "      dropout=0.0,\n",
    "      dropatt=0.0,\n",
    "      init='normal',\n",
    "      init_range=0.1,\n",
    "      init_std=0.05,\n",
    "      clamp_len=-1)\n",
    "\n",
    "xlnet_parameters = xlnet.RunConfig(**kwargs)\n",
    "xlnet_config = xlnet.XLNetConfig(json_path='xlnet-base-29-03-2020/config.json')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:memory input None\n",
      "INFO:tensorflow:Use float type <dtype: 'float32'>\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/client/session.py:1750: UserWarning: An interactive session is already active. This can cause out-of-memory errors in some cases. You must explicitly call `InteractiveSession.close()` to release resources held by the other session(s).\n",
      "  warnings.warn('An interactive session is already active. This can '\n"
     ]
    }
   ],
   "source": [
    "dimension_output = 2\n",
    "learning_rate = 2e-5\n",
    "\n",
    "# Close the training session before resetting the graph. Opening a second\n",
    "# InteractiveSession while the first is still active triggers TensorFlow's\n",
    "# 'An interactive session is already active' warning and keeps its resources allocated.\n",
    "sess.close()\n",
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "# Model picks up the evaluation RunConfig from the global xlnet_parameters.\n",
    "model = Model(\n",
    "    dimension_output,\n",
    "    learning_rate\n",
    ")\n",
    "\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from xlnet-base-relevancy/model.ckpt\n"
     ]
    }
   ],
   "source": [
    "# Restore the trainable variables from the saved checkpoint into `sess`.\n",
    "checkpoint_path = 'xlnet-base-relevancy/model.ckpt'\n",
    "saver = tf.train.Saver(var_list = tf.trainable_variables())\n",
    "saver.restore(sess, checkpoint_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm as tqdm_base\n",
    "\n",
    "\n",
    "def tqdm(*args, **kwargs):\n",
    "    \"\"\"Build a tqdm bar after deregistering every bar tqdm still tracks.\n",
    "\n",
    "    All positional and keyword arguments are forwarded to `tqdm_base`.\n",
    "    \"\"\"\n",
    "    # Copy into a list first: deregistering mutates the tracked collection.\n",
    "    for stale_bar in list(getattr(tqdm_base, '_instances', ())):\n",
    "        tqdm_base._decr_instances(stale_bar)\n",
    "    return tqdm_base(*args, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "test minibatch loop: 100%|██████████| 629/629 [02:20<00:00,  4.47it/s]\n"
     ]
    }
   ],
   "source": [
    "real_Y, predict_Y = [], []\n",
    "\n",
    "# Run the test set through the model batch by batch, collecting the arg-max\n",
    "# class of the logits next to the reference labels.\n",
    "total = len(test_input_ids)\n",
    "pbar = tqdm(range(0, total, batch_size), desc = 'test minibatch loop')\n",
    "for start in pbar:\n",
    "    window = slice(start, min(start + batch_size, total))\n",
    "    batch_y = test_Y[window]\n",
    "    # Sequences are right-padded per batch; masks pad with 1, segments with 4.\n",
    "    feed = {\n",
    "        model.Y: batch_y,\n",
    "        model.X: pad_sequences(test_input_ids[window], padding = 'post'),\n",
    "        model.segment_ids: pad_sequences(\n",
    "            test_segment_ids[window], padding = 'post', value = 4\n",
    "        ),\n",
    "        model.input_masks: pad_sequences(\n",
    "            test_input_masks[window], padding = 'post', value = 1\n",
    "        ),\n",
    "    }\n",
    "    batch_logits = sess.run(model.logits, feed_dict = feed)\n",
    "    predict_Y.extend(np.argmax(batch_logits, 1).tolist())\n",
    "    real_Y.extend(batch_y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "not relevant    0.92476   0.88945   0.90676      1990\n",
      "    relevant    0.92937   0.95262   0.94085      3039\n",
      "\n",
      "    accuracy                        0.92762      5029\n",
      "   macro avg    0.92707   0.92103   0.92381      5029\n",
      "weighted avg    0.92755   0.92762   0.92736      5029\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from sklearn import metrics\n",
    "\n",
    "# Per-class precision / recall / F1 for the collected predictions, 5 decimals.\n",
    "class_names = ['not relevant', 'relevant']\n",
    "report = metrics.classification_report(\n",
    "    real_Y, predict_Y, target_names = class_names, digits = 5\n",
    ")\n",
    "print(report)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Placeholder',\n",
       " 'Placeholder_1',\n",
       " 'Placeholder_2',\n",
       " 'Placeholder_3',\n",
       " 'model/transformer/r_w_bias',\n",
       " 'model/transformer/r_r_bias',\n",
       " 'model/transformer/word_embedding/lookup_table',\n",
       " 'model/transformer/r_s_bias',\n",
       " 'model/transformer/seg_embed',\n",
       " 'model/transformer/layer_0/rel_attn/q/kernel',\n",
       " 'model/transformer/layer_0/rel_attn/k/kernel',\n",
       " 'model/transformer/layer_0/rel_attn/v/kernel',\n",
       " 'model/transformer/layer_0/rel_attn/r/kernel',\n",
       " 'model/transformer/layer_0/rel_attn/o/kernel',\n",
       " 'model/transformer/layer_0/rel_attn/LayerNorm/gamma',\n",
       " 'model/transformer/layer_0/ff/layer_1/kernel',\n",
       " 'model/transformer/layer_0/ff/layer_1/bias',\n",
       " 'model/transformer/layer_0/ff/layer_2/kernel',\n",
       " 'model/transformer/layer_0/ff/layer_2/bias',\n",
       " 'model/transformer/layer_0/ff/LayerNorm/gamma',\n",
       " 'model/transformer/layer_1/rel_attn/q/kernel',\n",
       " 'model/transformer/layer_1/rel_attn/k/kernel',\n",
       " 'model/transformer/layer_1/rel_attn/v/kernel',\n",
       " 'model/transformer/layer_1/rel_attn/r/kernel',\n",
       " 'model/transformer/layer_1/rel_attn/o/kernel',\n",
       " 'model/transformer/layer_1/rel_attn/LayerNorm/gamma',\n",
       " 'model/transformer/layer_1/ff/layer_1/kernel',\n",
       " 'model/transformer/layer_1/ff/layer_1/bias',\n",
       " 'model/transformer/layer_1/ff/layer_2/kernel',\n",
       " 'model/transformer/layer_1/ff/layer_2/bias',\n",
       " 'model/transformer/layer_1/ff/LayerNorm/gamma',\n",
       " 'model/transformer/layer_2/rel_attn/q/kernel',\n",
       " 'model/transformer/layer_2/rel_attn/k/kernel',\n",
       " 'model/transformer/layer_2/rel_attn/v/kernel',\n",
       " 'model/transformer/layer_2/rel_attn/r/kernel',\n",
       " 'model/transformer/layer_2/rel_attn/o/kernel',\n",
       " 'model/transformer/layer_2/rel_attn/LayerNorm/gamma',\n",
       " 'model/transformer/layer_2/ff/layer_1/kernel',\n",
       " 'model/transformer/layer_2/ff/layer_1/bias',\n",
       " 'model/transformer/layer_2/ff/layer_2/kernel',\n",
       " 'model/transformer/layer_2/ff/layer_2/bias',\n",
       " 'model/transformer/layer_2/ff/LayerNorm/gamma',\n",
       " 'model/transformer/layer_3/rel_attn/q/kernel',\n",
       " 'model/transformer/layer_3/rel_attn/k/kernel',\n",
       " 'model/transformer/layer_3/rel_attn/v/kernel',\n",
       " 'model/transformer/layer_3/rel_attn/r/kernel',\n",
       " 'model/transformer/layer_3/rel_attn/o/kernel',\n",
       " 'model/transformer/layer_3/rel_attn/LayerNorm/gamma',\n",
       " 'model/transformer/layer_3/ff/layer_1/kernel',\n",
       " 'model/transformer/layer_3/ff/layer_1/bias',\n",
       " 'model/transformer/layer_3/ff/layer_2/kernel',\n",
       " 'model/transformer/layer_3/ff/layer_2/bias',\n",
       " 'model/transformer/layer_3/ff/LayerNorm/gamma',\n",
       " 'model/transformer/layer_4/rel_attn/q/kernel',\n",
       " 'model/transformer/layer_4/rel_attn/k/kernel',\n",
       " 'model/transformer/layer_4/rel_attn/v/kernel',\n",
       " 'model/transformer/layer_4/rel_attn/r/kernel',\n",
       " 'model/transformer/layer_4/rel_attn/o/kernel',\n",
       " 'model/transformer/layer_4/rel_attn/LayerNorm/gamma',\n",
       " 'model/transformer/layer_4/ff/layer_1/kernel',\n",
       " 'model/transformer/layer_4/ff/layer_1/bias',\n",
       " 'model/transformer/layer_4/ff/layer_2/kernel',\n",
       " 'model/transformer/layer_4/ff/layer_2/bias',\n",
       " 'model/transformer/layer_4/ff/LayerNorm/gamma',\n",
       " 'model/transformer/layer_5/rel_attn/q/kernel',\n",
       " 'model/transformer/layer_5/rel_attn/k/kernel',\n",
       " 'model/transformer/layer_5/rel_attn/v/kernel',\n",
       " 'model/transformer/layer_5/rel_attn/r/kernel',\n",
       " 'model/transformer/layer_5/rel_attn/o/kernel',\n",
       " 'model/transformer/layer_5/rel_attn/LayerNorm/gamma',\n",
       " 'model/transformer/layer_5/ff/layer_1/kernel',\n",
       " 'model/transformer/layer_5/ff/layer_1/bias',\n",
       " 'model/transformer/layer_5/ff/layer_2/kernel',\n",
       " 'model/transformer/layer_5/ff/layer_2/bias',\n",
       " 'model/transformer/layer_5/ff/LayerNorm/gamma',\n",
       " 'model/transformer/layer_6/rel_attn/q/kernel',\n",
       " 'model/transformer/layer_6/rel_attn/k/kernel',\n",
       " 'model/transformer/layer_6/rel_attn/v/kernel',\n",
       " 'model/transformer/layer_6/rel_attn/r/kernel',\n",
       " 'model/transformer/layer_6/rel_attn/o/kernel',\n",
       " 'model/transformer/layer_6/rel_attn/LayerNorm/gamma',\n",
       " 'model/transformer/layer_6/ff/layer_1/kernel',\n",
       " 'model/transformer/layer_6/ff/layer_1/bias',\n",
       " 'model/transformer/layer_6/ff/layer_2/kernel',\n",
       " 'model/transformer/layer_6/ff/layer_2/bias',\n",
       " 'model/transformer/layer_6/ff/LayerNorm/gamma',\n",
       " 'model/transformer/layer_7/rel_attn/q/kernel',\n",
       " 'model/transformer/layer_7/rel_attn/k/kernel',\n",
       " 'model/transformer/layer_7/rel_attn/v/kernel',\n",
       " 'model/transformer/layer_7/rel_attn/r/kernel',\n",
       " 'model/transformer/layer_7/rel_attn/o/kernel',\n",
       " 'model/transformer/layer_7/rel_attn/LayerNorm/gamma',\n",
       " 'model/transformer/layer_7/ff/layer_1/kernel',\n",
       " 'model/transformer/layer_7/ff/layer_1/bias',\n",
       " 'model/transformer/layer_7/ff/layer_2/kernel',\n",
       " 'model/transformer/layer_7/ff/layer_2/bias',\n",
       " 'model/transformer/layer_7/ff/LayerNorm/gamma',\n",
       " 'model/transformer/layer_8/rel_attn/q/kernel',\n",
       " 'model/transformer/layer_8/rel_attn/k/kernel',\n",
       " 'model/transformer/layer_8/rel_attn/v/kernel',\n",
       " 'model/transformer/layer_8/rel_attn/r/kernel',\n",
       " 'model/transformer/layer_8/rel_attn/o/kernel',\n",
       " 'model/transformer/layer_8/rel_attn/LayerNorm/gamma',\n",
       " 'model/transformer/layer_8/ff/layer_1/kernel',\n",
       " 'model/transformer/layer_8/ff/layer_1/bias',\n",
       " 'model/transformer/layer_8/ff/layer_2/kernel',\n",
       " 'model/transformer/layer_8/ff/layer_2/bias',\n",
       " 'model/transformer/layer_8/ff/LayerNorm/gamma',\n",
       " 'model/transformer/layer_9/rel_attn/q/kernel',\n",
       " 'model/transformer/layer_9/rel_attn/k/kernel',\n",
       " 'model/transformer/layer_9/rel_attn/v/kernel',\n",
       " 'model/transformer/layer_9/rel_attn/r/kernel',\n",
       " 'model/transformer/layer_9/rel_attn/o/kernel',\n",
       " 'model/transformer/layer_9/rel_attn/LayerNorm/gamma',\n",
       " 'model/transformer/layer_9/ff/layer_1/kernel',\n",
       " 'model/transformer/layer_9/ff/layer_1/bias',\n",
       " 'model/transformer/layer_9/ff/layer_2/kernel',\n",
       " 'model/transformer/layer_9/ff/layer_2/bias',\n",
       " 'model/transformer/layer_9/ff/LayerNorm/gamma',\n",
       " 'model/transformer/layer_10/rel_attn/q/kernel',\n",
       " 'model/transformer/layer_10/rel_attn/k/kernel',\n",
       " 'model/transformer/layer_10/rel_attn/v/kernel',\n",
       " 'model/transformer/layer_10/rel_attn/r/kernel',\n",
       " 'model/transformer/layer_10/rel_attn/o/kernel',\n",
       " 'model/transformer/layer_10/rel_attn/LayerNorm/gamma',\n",
       " 'model/transformer/layer_10/ff/layer_1/kernel',\n",
       " 'model/transformer/layer_10/ff/layer_1/bias',\n",
       " 'model/transformer/layer_10/ff/layer_2/kernel',\n",
       " 'model/transformer/layer_10/ff/layer_2/bias',\n",
       " 'model/transformer/layer_10/ff/LayerNorm/gamma',\n",
       " 'model/transformer/layer_11/rel_attn/q/kernel',\n",
       " 'model/transformer/layer_11/rel_attn/k/kernel',\n",
       " 'model/transformer/layer_11/rel_attn/v/kernel',\n",
       " 'model/transformer/layer_11/rel_attn/r/kernel',\n",
       " 'model/transformer/layer_11/rel_attn/o/kernel',\n",
       " 'model/transformer/layer_11/rel_attn/LayerNorm/gamma',\n",
       " 'model/transformer/layer_11/ff/layer_1/kernel',\n",
       " 'model/transformer/layer_11/ff/layer_1/bias',\n",
       " 'model/transformer/layer_11/ff/layer_2/kernel',\n",
       " 'model/transformer/layer_11/ff/layer_2/bias',\n",
       " 'model/transformer/layer_11/ff/LayerNorm/gamma',\n",
       " 'dense/kernel',\n",
       " 'dense/bias',\n",
       " 'logits_seq',\n",
       " 'logits']"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Substrings that select a node by name, and substrings that always reject it.\n",
    "wanted_names = ('Placeholder', 'logits', 'alphas', 'self/Softmax')\n",
    "rejected_names = ('Adam', 'beta', 'global_step')\n",
    "\n",
    "selected_nodes = []\n",
    "for node in tf.get_default_graph().as_graph_def().node:\n",
    "    # Keep variable ops and nodes whose name matches a wanted substring ...\n",
    "    is_wanted = 'Variable' in node.op or any(\n",
    "        key in node.name for key in wanted_names\n",
    "    )\n",
    "    # ... unless the name also contains one of the rejected substrings.\n",
    "    is_rejected = any(key in node.name for key in rejected_names)\n",
    "    if is_wanted and not is_rejected:\n",
    "        selected_nodes.append(node.name)\n",
    "\n",
    "strings = ','.join(selected_nodes)\n",
    "strings.split(',')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "def freeze_graph(model_dir, output_node_names):\n",
    "    \"\"\"Freeze the latest checkpoint in `model_dir` into `frozen_model.pb`.\n",
    "\n",
    "    The checkpoint's meta graph is imported into a fresh graph, its variables\n",
    "    are restored and converted to constants, and the serialized GraphDef is\n",
    "    written into the directory of the checkpoint path.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    model_dir: str\n",
    "        Directory holding the checkpoint state and the `.meta` file.\n",
    "    output_node_names: str\n",
    "        Comma-separated names of the graph nodes to keep as outputs.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    AssertionError\n",
    "        If `model_dir` does not exist or holds no checkpoint state.\n",
    "    \"\"\"\n",
    "\n",
    "    if not tf.gfile.Exists(model_dir):\n",
    "        raise AssertionError(\n",
    "            \"Export directory doesn't exists. Please specify an export \"\n",
    "            'directory: %s' % model_dir\n",
    "        )\n",
    "\n",
    "    checkpoint = tf.train.get_checkpoint_state(model_dir)\n",
    "    # get_checkpoint_state returns None when the directory has no checkpoint\n",
    "    # state; fail with a clear message instead of an AttributeError below.\n",
    "    if checkpoint is None or not checkpoint.model_checkpoint_path:\n",
    "        raise AssertionError(\n",
    "            'No checkpoint found in directory: %s' % model_dir\n",
    "        )\n",
    "    input_checkpoint = checkpoint.model_checkpoint_path\n",
    "\n",
    "    # os.path.dirname yields '' for a bare file name, so the frozen graph is\n",
    "    # written beside the checkpoint instead of at the filesystem root.\n",
    "    output_graph = os.path.join(\n",
    "        os.path.dirname(input_checkpoint), 'frozen_model.pb'\n",
    "    )\n",
    "    # Tolerate spaces around the commas and skip empty entries.\n",
    "    output_nodes = [\n",
    "        name.strip() for name in output_node_names.split(',') if name.strip()\n",
    "    ]\n",
    "\n",
    "    with tf.Session(graph = tf.Graph()) as sess:\n",
    "        saver = tf.train.import_meta_graph(\n",
    "            input_checkpoint + '.meta', clear_devices = True\n",
    "        )\n",
    "        saver.restore(sess, input_checkpoint)\n",
    "        output_graph_def = tf.graph_util.convert_variables_to_constants(\n",
    "            sess,\n",
    "            sess.graph.as_graph_def(),\n",
    "            output_nodes,\n",
    "        )\n",
    "        with tf.gfile.GFile(output_graph, 'wb') as f:\n",
    "            f.write(output_graph_def.SerializeToString())\n",
    "        print('%d ops in the final graph.' % len(output_graph_def.node))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from xlnet-base-relevancy/model.ckpt\n",
      "INFO:tensorflow:Froze 163 variables.\n",
      "INFO:tensorflow:Converted 163 variables to const ops.\n",
      "7673 ops in the final graph.\n"
     ]
    }
   ],
   "source": [
    "# Freeze the saved checkpoint, keeping the node names selected in `strings`.\n",
    "freeze_graph(model_dir = 'xlnet-base-relevancy', output_node_names = strings)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
